{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "%tensorflow_version 2.x\n",
    "import os\n",
    "os.chdir('/content/drive/MyDrive/tacotron2')\n",
    "!git submodule init\n",
    "!git submodule update\n",
    "!pip install -q unidecode tensorboardX\n",
    "!pip install pypinyin\n",
    "!pip install librosa==0.9.2\n",
    "import librosa\n",
    "#@!pip install matplotlib==3.9.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "if os.getcwd() != '/content/drive/MyDrive/tacotron2':\n",
    "    os.chdir('/content/drive/MyDrive/tacotron2')\n",
    "import time\n",
    "import argparse\n",
    "import math\n",
    "from numpy import finfo\n",
    "import torch\n",
    "from distributed import apply_gradient_allreduce\n",
    "import torch.distributed as dist\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from torch.utils.data import DataLoader\n",
    "from model import Tacotron2\n",
    "from data_utils import TextMelLoader, TextMelCollate\n",
    "from loss_function import Tacotron2Loss\n",
    "from logger import Tacotron2Logger\n",
    "from hparams import create_hparams\n",
    "import random\n",
    "import numpy as np\n",
    "import layers\n",
    "from utils import load_wav_to_torch, load_filepaths_and_text\n",
    "from text import text_to_sequence\n",
    "from math import e\n",
    "from tqdm.notebook import tqdm\n",
    "from distutils.dir_util import copy_tree\n",
    "import matplotlib.pylab as plt\n",
    "def download_from_google_drive(file_id, file_name):\n",
    "  !rm -f ./cookie\n",
    "  !curl -c ./cookie -s -L \"https://drive.google.com/uc?export=download&id={file_id}\" > /dev/null\n",
    "  confirm_text = !awk '/download/ {print $NF}' ./cookie\n",
    "  confirm_text = confirm_text[0]\n",
    "  !curl -Lb ./cookie \"https://drive.google.com/uc?export=download&confirm={confirm_text}&id={file_id}\" -o {file_name}\n",
    "def create_mels():\n",
    "    print(\"Generating Mels\")\n",
    "    stft = layers.TacotronSTFT(\n",
    "                hparams.filter_length, hparams.hop_length, hparams.win_length,\n",
    "                hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,\n",
    "                hparams.mel_fmax)\n",
    "    def save_mel(filename):\n",
    "        audio, sampling_rate = load_wav_to_torch(filename)\n",
    "        if sampling_rate != stft.sampling_rate:\n",
    "            raise ValueError(\"{} {} SR doesn't match target {} SR\".format(filename, \n",
    "                sampling_rate, stft.sampling_rate))\n",
    "        audio_norm = audio / hparams.max_wav_value\n",
    "        audio_norm = audio_norm.unsqueeze(0)\n",
    "        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)\n",
    "        melspec = stft.mel_spectrogram(audio_norm)\n",
    "        melspec = torch.squeeze(melspec, 0).cpu().numpy()\n",
    "        np.save(filename.replace('.wav', ''), melspec)\n",
    "    import glob\n",
    "    wavs = glob.glob('Samples/*/*.wav')\n",
    "    for i in tqdm(wavs):\n",
    "        save_mel(i)\n",
    "def reduce_tensor(tensor, n_gpus):\n",
    "    rt = tensor.clone()\n",
    "    dist.all_reduce(rt, op=dist.reduce_op.SUM)\n",
    "    rt /= n_gpus\n",
    "    return rt\n",
    "def init_distributed(hparams, n_gpus, rank, group_name):\n",
    "    assert torch.cuda.is_available(), \"Distributed mode requires CUDA.\"\n",
    "    print(\"Initializing Distributed\")\n",
    "    torch.cuda.set_device(rank % torch.cuda.device_count())\n",
    "    dist.init_process_group(\n",
    "        backend=hparams.dist_backend, init_method=hparams.dist_url,\n",
    "        world_size=n_gpus, rank=rank, group_name=group_name)\n",
    "    print(\"Done initializing distributed\")\n",
    "def prepare_dataloaders(hparams):\n",
    "    trainset = TextMelLoader(hparams.training_files, hparams)\n",
    "    valset = TextMelLoader(hparams.validation_files, hparams)\n",
    "    collate_fn = TextMelCollate(hparams.n_frames_per_step)\n",
    "    if hparams.distributed_run:\n",
    "        train_sampler = DistributedSampler(trainset)\n",
    "        shuffle = False\n",
    "    else:\n",
    "        train_sampler = None\n",
    "        shuffle = True\n",
    "    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,\n",
    "                              sampler=train_sampler,\n",
    "                              batch_size=hparams.batch_size, pin_memory=False,\n",
    "                              drop_last=True, collate_fn=collate_fn)\n",
    "    return train_loader, valset, collate_fn\n",
    "def prepare_directories_and_logger(output_directory, log_directory, rank):\n",
    "    if rank == 0:\n",
    "        if not os.path.isdir(output_directory):\n",
    "            os.makedirs(output_directory)\n",
    "            os.chmod(output_directory, 0o775)\n",
    "        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))\n",
    "    else:\n",
    "        logger = None\n",
    "    return logger\n",
    "def load_model(hparams):\n",
    "    model = Tacotron2(hparams).cuda()\n",
    "    if hparams.fp16_run:\n",
    "        model.decoder.attention_layer.score_mask_value = finfo('float16').min\n",
    "    if hparams.distributed_run:\n",
    "        model = apply_gradient_allreduce(model)\n",
    "    return model\n",
    "def warm_start_model(checkpoint_path, model, ignore_layers):\n",
    "    assert os.path.isfile(checkpoint_path)\n",
    "    print(\"Warm starting model from checkpoint '{}'\".format(checkpoint_path))\n",
    "    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')\n",
    "    model_dict = checkpoint_dict['state_dict']\n",
    "    if len(ignore_layers) > 0:\n",
    "        model_dict = {k: v for k, v in model_dict.items()\n",
    "                      if k not in ignore_layers}\n",
    "        dummy_dict = model.state_dict()\n",
    "        dummy_dict.update(model_dict)\n",
    "        model_dict = dummy_dict\n",
    "    model.load_state_dict(model_dict)\n",
    "    return model\n",
    "def load_checkpoint(checkpoint_path, model, optimizer):\n",
    "    assert os.path.isfile(checkpoint_path)\n",
    "    print(\"Loading checkpoint '{}'\".format(checkpoint_path))\n",
    "    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')\n",
    "    model.load_state_dict(checkpoint_dict['state_dict'])\n",
    "    optimizer.load_state_dict(checkpoint_dict['optimizer'])\n",
    "    learning_rate = checkpoint_dict['learning_rate']\n",
    "    iteration = checkpoint_dict['iteration']\n",
    "    print(\"Loaded checkpoint '{}' from iteration {}\" .format(\n",
    "        checkpoint_path, iteration))\n",
    "    return model, optimizer, learning_rate, iteration\n",
    "def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):\n",
    "    print(\"Saving model and optimizer state at iteration {} to {}\".format(\n",
    "        iteration, filepath))\n",
    "    try:\n",
    "        torch.save({'iteration': iteration,\n",
    "                'state_dict': model.state_dict(),\n",
    "                'optimizer': optimizer.state_dict(),\n",
    "                'learning_rate': learning_rate}, filepath)\n",
    "    except KeyboardInterrupt:\n",
    "        print(\"interrupt received while saving, waiting for save to complete.\")\n",
    "        torch.save({'iteration': iteration,'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(),'learning_rate': learning_rate}, filepath)\n",
    "    print(\"Model Saved\")\n",
    "def plot_alignment(alignment, info=None):\n",
    "    fig, ax = plt.subplots(figsize=(int(alignment_graph_width/100), int(alignment_graph_height/100)))\n",
    "    im = ax.imshow(alignment, cmap='inferno', aspect='auto', origin='lower',\n",
    "                   interpolation='none')\n",
    "    ax.autoscale(enable=True, axis=\"y\", tight=True)\n",
    "    fig.colorbar(im, ax=ax)\n",
    "    xlabel = 'Decoder timestep'\n",
    "    if info is not None:\n",
    "        xlabel += '\\n\\n' + info\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel('Encoder timestep')\n",
    "    plt.tight_layout()\n",
    "    fig.canvas.draw()\n",
    "    plt.show()\n",
    "def validate(model, criterion, valset, iteration, batch_size, n_gpus,\n",
    "             collate_fn, logger, distributed_run, rank, epoch, start_eposh, learning_rate):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        val_sampler = DistributedSampler(valset) if distributed_run else None\n",
    "        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,\n",
    "                                shuffle=False, batch_size=batch_size,\n",
    "                                pin_memory=False, collate_fn=collate_fn)\n",
    "        val_loss = 0.0\n",
    "        for i, batch in enumerate(val_loader):\n",
    "            x, y = model.parse_batch(batch)\n",
    "            y_pred = model(x)\n",
    "            loss = criterion(y_pred, y)\n",
    "            if distributed_run:\n",
    "                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()\n",
    "            else:\n",
    "                reduced_val_loss = loss.item()\n",
    "            val_loss += reduced_val_loss\n",
    "        val_loss = val_loss / (i + 1)\n",
    "    model.train()\n",
    "    if rank == 0:\n",
    "        print(\"Epoch: {} Validation loss {}: {:9f}  Time: {:.1f}m LR: {:.6f}\".format(epoch, iteration, val_loss,(time.perf_counter()-start_eposh)/60, learning_rate))\n",
    "        logger.log_validation(val_loss, model, y, y_pred, iteration)\n",
    "        if hparams.show_alignments:\n",
    "            _, mel_outputs, gate_outputs, alignments = y_pred\n",
    "            idx = random.randint(0, alignments.size(0) - 1)\n",
    "            plot_alignment(alignments[idx].data.cpu().numpy().T)\n",
    "def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,\n",
    "          rank, group_name, hparams, log_directory2):\n",
    "    if hparams.distributed_run:\n",
    "        init_distributed(hparams, n_gpus, rank, group_name)\n",
    "    torch.manual_seed(hparams.seed)\n",
    "    torch.cuda.manual_seed(hparams.seed)\n",
    "    model = load_model(hparams)\n",
    "    learning_rate = hparams.learning_rate\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,\n",
    "                                 weight_decay=hparams.weight_decay)\n",
    "    if hparams.fp16_run:\n",
    "        from apex import amp\n",
    "        model, optimizer = amp.initialize(\n",
    "            model, optimizer, opt_level='O2')\n",
    "    if hparams.distributed_run:\n",
    "        model = apply_gradient_allreduce(model)\n",
    "    criterion = Tacotron2Loss()\n",
    "    logger = prepare_directories_and_logger(\n",
    "        output_directory, log_directory, rank)\n",
    "    train_loader, valset, collate_fn = prepare_dataloaders(hparams)\n",
    "    iteration = 0\n",
    "    epoch_offset = 0\n",
    "    if checkpoint_path is not None and os.path.isfile(checkpoint_path):\n",
    "        if warm_start:\n",
    "            model = warm_start_model(\n",
    "                checkpoint_path, model, hparams.ignore_layers)\n",
    "        else:\n",
    "            model, optimizer, _learning_rate, iteration = load_checkpoint(\n",
    "                checkpoint_path, model, optimizer)\n",
    "            if hparams.use_saved_learning_rate:\n",
    "                learning_rate = _learning_rate\n",
    "            iteration += 1\n",
    "            epoch_offset = max(0, int(iteration / len(train_loader)))\n",
    "    else:\n",
    "      os.path.isfile(\"tacotron2_statedict.pt\")\n",
    "      model = warm_start_model(\"tacotron2_statedict.pt\", model, hparams.ignore_layers)\n",
    "    start_eposh = time.perf_counter()\n",
    "    learning_rate = 0.0\n",
    "    model.train()\n",
    "    is_overflow = False\n",
    "    for epoch in tqdm(range(epoch_offset, hparams.epochs)):\n",
    "        print(\"\\nStarting Epoch: {} Iteration: {}\".format(epoch, iteration))\n",
    "        start_eposh = time.perf_counter()\n",
    "        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "            start = time.perf_counter()\n",
    "            if iteration < hparams.decay_start: learning_rate = hparams.A_\n",
    "            else: iteration_adjusted = iteration - hparams.decay_start; learning_rate = (hparams.A_*(e**(-iteration_adjusted/hparams.B_))) + hparams.C_\n",
    "            learning_rate = max(hparams.min_learning_rate, learning_rate)\n",
    "            for param_group in optimizer.param_groups:\n",
    "                param_group['lr'] = learning_rate\n",
    "            model.zero_grad()\n",
    "            x, y = model.parse_batch(batch)\n",
    "            y_pred = model(x)\n",
    "            loss = criterion(y_pred, y)\n",
    "            if hparams.distributed_run:\n",
    "                reduced_loss = reduce_tensor(loss.data, n_gpus).item()\n",
    "            else:\n",
    "                reduced_loss = loss.item()\n",
    "            if hparams.fp16_run:\n",
    "                with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
    "                    scaled_loss.backward()\n",
    "            else:\n",
    "                loss.backward()\n",
    "            if hparams.fp16_run:\n",
    "                grad_norm = torch.nn.utils.clip_grad_norm_(\n",
    "                    amp.master_params(optimizer), hparams.grad_clip_thresh)\n",
    "                is_overflow = math.isnan(grad_norm)\n",
    "            else:\n",
    "                grad_norm = torch.nn.utils.clip_grad_norm_(\n",
    "                    model.parameters(), hparams.grad_clip_thresh)\n",
    "            optimizer.step()\n",
    "            if not is_overflow and rank == 0:\n",
    "                duration = time.perf_counter() - start\n",
    "                logger.log_training(\n",
    "                    reduced_loss, grad_norm, learning_rate, duration, iteration)\n",
    "            iteration += 1\n",
    "        validate(model, criterion, valset, iteration,\n",
    "                 hparams.batch_size, n_gpus, collate_fn, logger,\n",
    "                 hparams.distributed_run, rank, epoch, start_eposh, learning_rate)\n",
    "        save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)\n",
    "        if log_directory2 != None:\n",
    "            copy_tree(log_directory, log_directory2)\n",
    "def check_dataset(hparams):\n",
    "    from utils import load_wav_to_torch, load_filepaths_and_text\n",
    "    import os\n",
    "    import numpy as np\n",
    "    def check_arr(filelist_arr):\n",
    "        for i, file in enumerate(filelist_arr):\n",
    "            if len(file) > 2:\n",
    "                print(\"|\".join(file), \"\\nhas multiple '|', this may not be an error.\")\n",
    "            if hparams.load_mel_from_disk and '.wav' in file[0]:\n",
    "                print(\"[WARNING]\", file[0], \" in filelist while expecting .npy .\")\n",
    "            else:\n",
    "                if not hparams.load_mel_from_disk and '.npy' in file[0]:\n",
    "                    print(\"[WARNING]\", file[0], \" in filelist while expecting .wav .\")\n",
    "            if (not os.path.exists(file[0])):\n",
    "                print(\"|\".join(file), \"\\n[WARNING] does not exist.\")\n",
    "            if len(file[1]) < 3:\n",
    "                print(\"|\".join(file), \"\\n[info] has no/very little text.\")\n",
    "            if not ((file[1].strip())[-1] in r\"!?,.;:\"):\n",
    "                print(\"|\".join(file), \"\\n[info] has no ending punctuation.\")\n",
    "            mel_length = 1\n",
    "            if hparams.load_mel_from_disk and '.npy' in file[0]:\n",
    "                melspec = torch.from_numpy(np.load(file[0], allow_pickle=True))\n",
    "                mel_length = melspec.shape[1]\n",
    "            if mel_length == 0:\n",
    "                print(\"|\".join(file), \"\\n[WARNING] has 0 duration.\")\n",
    "    print(\"Checking Training Files\")\n",
    "    audiopaths_and_text = load_filepaths_and_text(hparams.training_files)\n",
    "    check_arr(audiopaths_and_text)\n",
    "    print(\"Checking Validation Files\")\n",
    "    audiopaths_and_text = load_filepaths_and_text(hparams.validation_files)\n",
    "    check_arr(audiopaths_and_text)\n",
    "    print(\"Finished Checking\")\n",
    "!sed -i -- 's,.wav|,.npy|,g' filelists/*.txt\n",
    "warm_start=Fals\n",
    "n_gpus=1\n",
    "rank=0\n",
    "group_name=None\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "model_filename = \"Samples\"\n",
    "hparams = create_hparams()\n",
    "hparams.training_files = \"/content/drive/MyDrive/tacotron2/filelists/training.txt\"\n",
    "hparams.validation_files = \"/content/drive/MyDrive/tacotron2/filelists/testing.txt\"\n",
    "hparams.batch_size = 30\n",
    "hparams.epochs = 600\n",
    "hparams.p_attention_dropout=0.1\n",
    "hparams.p_decoder_dropout=0.1\n",
    "hparams.decay_start = 15000     \n",
    "hparams.A_ = 5e-4                \n",
    "hparams.B_ = 8000                 \n",
    "hparams.C_ = 0                      \n",
    "hparams.min_learning_rate = 1e-5   \n",
    "generate_mels = True\n",
    "hparams.show_alignments = True\n",
    "alignment_graph_height = 600\n",
    "alignment_graph_width = 1000\n",
    "hparams.load_mel_from_disk = True\n",
    "hparams.ignore_layers = [] \n",
    "torch.backends.cudnn.enabled = hparams.cudnn_enabled\n",
    "torch.backends.cudnn.benchmark = hparams.cudnn_benchmark\n",
    "output_directory = '/content/drive/My Drive/colab/output' \n",
    "log_directory = '/content/tacotron2/logs'\n",
    "log_directory2 = '/content/drive/My Drive/colab/logs'\n",
    "checkpoint_path = output_directory+(r'/')+model_filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "if generate_mels:\n",
    "    create_mels()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "check_dataset(hparams)\n",
    "print('FP16 Run:', hparams.fp16_run)\n",
    "print('Dynamic Loss Scaling:', hparams.dynamic_loss_scaling)\n",
    "print('Distributed Run:', hparams.distributed_run)\n",
    "print('cuDNN Enabled:', hparams.cudnn_enabled)\n",
    "print('cuDNN Benchmark:', hparams.cudnn_benchmark)\n",
    "train(output_directory, log_directory, checkpoint_path,\n",
    "      warm_start, n_gpus, rank, group_name, hparams, log_directory2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "#@markdown Config:\n",
    "\n",
    "#@markdown Restart the code to apply any changes.\n",
    "\n",
    "#Add new characters here.\n",
    "#Universal HiFi-GAN (has some robotic noise): 1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW\n",
    "Tacotron2_Model = '/content/drive/MyDrive/colab/outdir/Samples'#@param {type:\"string\"}\n",
    "TACOTRON2_ID = Tacotron2_Model\n",
    "HIFIGAN_ID = \"1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW\"\n",
    "from pypinyin import lazy_pinyin,Style\n",
    "\n",
    "# Check if Initilized\n",
    "try:\n",
    "    initilized\n",
    "except NameError:\n",
    "    print(\"Setting up, please wait.\\n\")\n",
    "    !pip install tqdm -q\n",
    "    from tqdm.notebook import tqdm\n",
    "    with tqdm(total=5, leave=False) as pbar:\n",
    "        %tensorflow_version 2.x\n",
    "        import os\n",
    "        from os.path import exists, join, basename, splitext\n",
    "        !pip install gdown\n",
    "        git_repo_url = 'https://github.com/NVIDIA/tacotron2.git'\n",
    "        project_name = splitext(basename(git_repo_url))[0]\n",
    "        if not exists(project_name):\n",
    "            # clone and install\n",
    "            !git clone -q --recursive {git_repo_url}\n",
    "            !git clone -q --recursive https://github.com/SortAnon/hifi-gan\n",
    "            !pip install -q librosa unidecode\n",
    "        pbar.update(1) # downloaded TT2 and HiFi-GAN\n",
    "        import sys\n",
    "        sys.path.append('hifi-gan')\n",
    "        sys.path.append(project_name)\n",
    "        import time\n",
    "        import matplotlib\n",
    "        import matplotlib.pylab as plt\n",
    "        import gdown\n",
    "        d = 'https://drive.google.com/uc?id='\n",
    "        import IPython.display as ipd\n",
    "        import numpy as np\n",
    "        import torch\n",
    "        import json\n",
    "        from hparams import create_hparams\n",
    "        from model import Tacotron2\n",
    "        from layers import TacotronSTFT\n",
    "        from audio_processing import griffin_lim\n",
    "        from text import text_to_sequence\n",
    "        from env import AttrDict\n",
    "        from meldataset import MAX_WAV_VALUE\n",
    "        from models import Generator\n",
    "\n",
    "        pbar.update(1) # initialized Dependancies\n",
    "\n",
    "        graph_width = 900\n",
    "        graph_height = 360\n",
    "        def plot_data(data, figsize=(int(graph_width/100), int(graph_height/100))):\n",
    "            fig, axes = plt.subplots(1, len(data), figsize=figsize)\n",
    "            for i in range(len(data)):\n",
    "                axes[i].imshow(data[i], aspect='auto', origin='bottom', \n",
    "                            interpolation='none', cmap='inferno')\n",
    "            fig.canvas.draw()\n",
    "            plt.show()\n",
    "\n",
    "        # Setup Pronounciation Dictionary\n",
    "        !gdown --id '1E12g_sREdcH5vuZb44EZYX8JjGWQ9rRp'\n",
    "        thisdict = {}\n",
    "        for line in reversed((open('merged.dict.txt', \"r\").read()).splitlines()):\n",
    "            thisdict[(line.split(\" \",1))[0]] = (line.split(\" \",1))[1].strip()\n",
    "\n",
    "        pbar.update(1) # Downloaded and Set up Pronounciation Dictionary\n",
    "\n",
    "        def ARPA(text, punctuation=r\"!?,.;\", EOS_Token=True):\n",
    "            out = ''\n",
    "            for word_ in text.split(\" \"):\n",
    "                word=word_; end_chars = ''\n",
    "                while any(elem in word for elem in punctuation) and len(word) > 1:\n",
    "                    if word[-1] in punctuation: end_chars = word[-1] + end_chars; word = word[:-1]\n",
    "                    else: break\n",
    "                try:\n",
    "                    word_arpa = thisdict[word.upper()]\n",
    "                    word = \"{\" + str(word_arpa) + \"}\"\n",
    "                except KeyError: pass\n",
    "                out = (out + \" \" + word + end_chars).strip()\n",
    "            if EOS_Token and out[-1] != \";\": out += \";\"\n",
    "            return out\n",
    "\n",
    "        def get_hifigan(MODEL_ID):\n",
    "            # Download HiFi-GAN\n",
    "            hifigan_pretrained_model = 'hifimodel'\n",
    "            gdown.download(d+MODEL_ID, hifigan_pretrained_model, quiet=False)\n",
    "            if not exists(hifigan_pretrained_model):\n",
    "                raise Exception(\"HiFI-GAN model failed to download!\")\n",
    "\n",
    "            # Load HiFi-GAN\n",
    "            conf = os.path.join(\"hifi-gan\", \"config_v1.json\")\n",
    "            with open(conf) as f:\n",
    "                json_config = json.loads(f.read())\n",
    "            h = AttrDict(json_config)\n",
    "            torch.manual_seed(h.seed)\n",
    "            hifigan = Generator(h).to(torch.device(\"cuda\"))\n",
    "            state_dict_g = torch.load(hifigan_pretrained_model, map_location=torch.device(\"cuda\"))\n",
    "            hifigan.load_state_dict(state_dict_g[\"generator\"])\n",
    "            hifigan.eval()\n",
    "            hifigan.remove_weight_norm()\n",
    "            return hifigan, h\n",
    "\n",
    "        hifigan, h = get_hifigan(HIFIGAN_ID)\n",
    "        pbar.update(1) # Downloaded and Set up HiFi-GAN\n",
    "\n",
    "        def has_MMI(STATE_DICT):\n",
    "            return any(True for x in STATE_DICT.keys() if \"mi.\" in x)\n",
    "\n",
    "        def get_Tactron2(MODEL_ID):\n",
    "            # Download Tacotron2\n",
    "            tacotron2_pretrained_model = TACOTRON2_ID\n",
    "            if not exists(tacotron2_pretrained_model):\n",
    "                raise Exception(\"Tacotron2 model failed to download!\")\n",
    "            # Load Tacotron2 and Config\n",
    "            hparams = create_hparams()\n",
    "            hparams.sampling_rate = 22050\n",
    "            hparams.max_decoder_steps = 3000 # Max Duration\n",
    "            hparams.gate_threshold = 0.25 # Model must be 25% sure the clip is over before ending generation\n",
    "            model = Tacotron2(hparams)\n",
    "            state_dict = torch.load(tacotron2_pretrained_model)['state_dict']\n",
    "            if has_MMI(state_dict):\n",
    "                raise Exception(\"ERROR: This notebook does not currently support MMI models.\")\n",
    "            model.load_state_dict(state_dict)\n",
    "            _ = model.cuda().eval().half()\n",
    "            return model, hparams\n",
    "\n",
    "        model, hparams = get_Tactron2(TACOTRON2_ID)\n",
    "        previous_tt2_id = TACOTRON2_ID\n",
    "\n",
    "        pbar.update(1) # Downloaded and Set up Tacotron2\n",
    "\n",
    "        # Extra Info\n",
    "        def end_to_end_infer(text, pronounciation_dictionary, show_graphs):\n",
    "            for i in [x for x in text.split(\"\\n\") if len(x)]:\n",
    "                if not pronounciation_dictionary:\n",
    "                    if i[-1] != \";\": i=i+\";\" \n",
    "                else: i = ARPA(i)\n",
    "                with torch.no_grad(): # save VRAM by not including gradients\n",
    "                    sequence = np.array(text_to_sequence(i, ['english_cleaners']))[None, :]\n",
    "                    sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n",
    "                    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)\n",
    "                    if show_graphs:\n",
    "                        plot_data((mel_outputs_postnet.float().data.cpu().numpy()[0],\n",
    "                                alignments.float().data.cpu().numpy()[0].T))\n",
    "                    y_g_hat = hifigan(mel_outputs_postnet.float())\n",
    "                    audio = y_g_hat.squeeze()\n",
    "                    audio = audio * MAX_WAV_VALUE\n",
    "                    print(\"\")\n",
    "                    ipd.display(ipd.Audio(audio.cpu().numpy().astype(\"int16\"), rate=hparams.sampling_rate))\n",
    "    from IPython.display import clear_output\n",
    "    clear_output()\n",
    "    initilized = \"Ready\"\n",
    "\n",
    "if previous_tt2_id != TACOTRON2_ID:\n",
    "    print(\"Updating Models\")\n",
    "    model, hparams = get_Tactron2(TACOTRON2_ID)\n",
    "    hifigan, h = get_hifigan(HIFIGAN_ID)\n",
    "    previous_tt2_id = TACOTRON2_ID\n",
    "\n",
    "pronounciation_dictionary = False #@param {type:\"boolean\"}\n",
    "# disables automatic ARPAbet conversion, useful for inputting your own ARPAbet pronounciations or just for testing\n",
    "show_graphs = True #@param {type:\"boolean\"}\n",
    "max_duration = 25 #this does nothing\n",
    "model.decoder.max_decoder_steps = 1000 #@param {type:\"integer\"}\n",
    "stop_threshold = 0.3 #@param {type:\"number\"}\n",
    "model.decoder.gate_threshold = stop_threshold\n",
    "\n",
    "#@markdown ---\n",
    "\n",
    "print(f\"Current Config:\\npronounciation_dictionary: {pronounciation_dictionary}\\nshow_graphs: {show_graphs}\\nmax_duration (in seconds): {max_duration}\\nstop_threshold: {stop_threshold}\\n\\n\")\n",
    "time.sleep(1)\n",
    "print(\"Enter/Paste your text.\")\n",
    "contents = []\n",
    "while True:\n",
    "    try:\n",
    "        print(\"-\"*50)\n",
    "        line = input()\n",
    "        if line != \"\":\n",
    "          line = \" \".join(lazy_pinyin(line, style=Style.TONE3))\n",
    "        print(line)\n",
    "        end_to_end_infer(line, pronounciation_dictionary, show_graphs)\n",
    "    except EOFError:\n",
    "        break\n",
    "    except KeyboardInterrupt:\n",
    "        print(\"Stopping...\")\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
